// SPDX-FileCopyrightText: 2022 Demerzel Solutions Limited
// SPDX-License-Identifier: LGPL-3.0-only

using System;
using FluentAssertions;
using Nethermind.Core;
using Nethermind.Core.Crypto;
using Nethermind.Core.Extensions;
using Nethermind.Core.Resettables;
using Nethermind.Core.Test;
using Nethermind.Core.Test.Builders;
using Nethermind.Db;
using Nethermind.Specs.Forks;
using Nethermind.Logging;
using Nethermind.Evm.State;
using Nethermind.Int256;
using Nethermind.State;
using NSubstitute;
using NUnit.Framework;

namespace Nethermind.Store.Test;

/// <summary>
/// Tests for the storage side of <see cref="WorldState"/>: persistent and transient storage cells,
/// snapshot/restore behaviour, commits, and the effect of storage changes on the state root.
/// </summary>
[Parallelizable(ParallelScope.All)]
public class StorageProviderTests
{
    private static readonly ILogManager LogManager = LimboLogs.Instance;

    // _values[i] is the single-byte array [i]; _values[0] ([0]) is what an unset cell reads back as in these tests.
    private readonly byte[][] _values =
    [
        [0],
        [1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7],
        [8],
        [9],
        [10],
        [11],
        [12],
    ];

    /// <summary>
    /// Committing and then restoring to <see cref="Snapshot.Empty"/> on a fresh context must not throw.
    /// </summary>
    [Test]
    public void Empty_commit_restore()
    {
        Context ctx = new();
        WorldState provider = BuildStorageProvider(ctx);
        provider.Commit(Frontier.Instance);
        provider.Restore(Snapshot.Empty);
    }

    /// <summary>
    /// Returns the <see cref="WorldState"/> under test. Storage and account state are served by the
    /// same instance, so this is simply the context's state provider.
    /// </summary>
    private static WorldState BuildStorageProvider(Context ctx)
    {
        return ctx.StateProvider;
    }

    /// <summary>
    /// Three successive writes to the same cell followed by a restore of persistent storage to position
    /// <paramref name="snapshot"/> must expose the value written at that position (zero for -1).
    /// </summary>
    [TestCase(-1)]
    [TestCase(0)]
    [TestCase(1)]
    [TestCase(2)]
    public void Same_address_same_index_different_values_restore(int snapshot)
    {
        Context ctx = new();
        WorldState provider = BuildStorageProvider(ctx);
        provider.Set(new StorageCell(ctx.Address1, 1), _values[1]);
        provider.Set(new StorageCell(ctx.Address1, 1), _values[2]);
        provider.Set(new StorageCell(ctx.Address1, 1), _values[3]);
        provider.Restore(Snapshot.EmptyPosition, snapshot, Snapshot.EmptyPosition);

        Assert.That(provider.Get(new StorageCell(ctx.Address1, 1)).ToArray(), Is.EqualTo(_values[snapshot + 1]));
    }

    /// <summary>
    /// After a commit, repeated set-then-restore cycles on the same cell must leave the committed value visible.
    /// </summary>
    [Test]
    public void Keep_in_cache()
    {
        Context ctx = new();
        WorldState provider = BuildStorageProvider(ctx);
        provider.Set(new StorageCell(ctx.Address1, 1), _values[1]);
        provider.Commit(Frontier.Instance);
        // Read the committed value before the uncommitted writes below; the result is deliberately discarded.
        provider.Get(new StorageCell(ctx.Address1, 1));
        provider.Set(new StorageCell(ctx.Address1, 1), _values[2]);
        provider.Restore(Snapshot.EmptyPosition, -1, Snapshot.EmptyPosition);
        provider.Set(new StorageCell(ctx.Address1, 1), _values[2]);
        provider.Restore(Snapshot.EmptyPosition, -1, Snapshot.EmptyPosition);
        provider.Set(new StorageCell(ctx.Address1, 1), _values[2]);
        provider.Restore(Snapshot.EmptyPosition, -1, Snapshot.EmptyPosition);
        Assert.That(provider.Get(new StorageCell(ctx.Address1, 1)).ToArray(), Is.EqualTo(_values[1]));
    }

    /// <summary>
    /// Writes to indices 1, 2 and 3 of one address; restoring to <paramref name="snapshot"/> must only roll back
    /// index 1 when the restore point precedes its write (snapshot -1).
    /// </summary>
    [TestCase(-1)]
    [TestCase(0)]
    [TestCase(1)]
    [TestCase(2)]
    public void Same_address_different_index(int snapshot)
    {
        Context ctx = new();
        WorldState provider = BuildStorageProvider(ctx);
        provider.Set(new StorageCell(ctx.Address1, 1), _values[1]);
        provider.Set(new StorageCell(ctx.Address1, 2), _values[2]);
        provider.Set(new StorageCell(ctx.Address1, 3), _values[3]);
        provider.Restore(Snapshot.EmptyPosition, snapshot, Snapshot.EmptyPosition);

        Assert.That(provider.Get(new StorageCell(ctx.Address1, 1)).ToArray(), Is.EqualTo(_values[Math.Min(snapshot + 1, 1)]));
    }

    /// <summary>
    /// Restoring to <see cref="Snapshot.Empty"/> after several commits must not roll back committed values.
    /// </summary>
    [Test]
    public void Commit_restore()
    {
        Context ctx = new();
        WorldState provider = BuildStorageProvider(ctx);
        provider.Set(new StorageCell(ctx.Address1, 1), _values[1]);
        provider.Set(new StorageCell(ctx.Address1, 2), _values[2]);
        provider.Set(new StorageCell(ctx.Address1, 3), _values[3]);
        provider.Commit(Frontier.Instance);
        provider.Set(new StorageCell(ctx.Address2, 1), _values[4]);
        provider.Set(new StorageCell(ctx.Address2, 2), _values[5]);
        provider.Set(new StorageCell(ctx.Address2, 3), _values[6]);
        provider.Commit(Frontier.Instance);
        provider.Set(new StorageCell(ctx.Address1, 1), _values[7]);
        provider.Set(new StorageCell(ctx.Address1, 2), _values[8]);
        provider.Set(new StorageCell(ctx.Address1, 3), _values[9]);
        provider.Commit(Frontier.Instance);
        provider.Set(new StorageCell(ctx.Address2, 1), _values[10]);
        provider.Set(new StorageCell(ctx.Address2, 2), _values[11]);
        provider.Set(new StorageCell(ctx.Address2, 3), _values[12]);
        provider.Commit(Frontier.Instance);
        provider.Restore(Snapshot.Empty);

        Assert.That(provider.Get(new StorageCell(ctx.Address1, 1)).ToArray(), Is.EqualTo(_values[7]));
        Assert.That(provider.Get(new StorageCell(ctx.Address1, 2)).ToArray(), Is.EqualTo(_values[8]));
        Assert.That(provider.Get(new StorageCell(ctx.Address1, 3)).ToArray(), Is.EqualTo(_values[9]));
        Assert.That(provider.Get(new StorageCell(ctx.Address2, 1)).ToArray(), Is.EqualTo(_values[10]));
        Assert.That(provider.Get(new StorageCell(ctx.Address2, 2)).ToArray(), Is.EqualTo(_values[11]));
        Assert.That(provider.Get(new StorageCell(ctx.Address2, 3)).ToArray(), Is.EqualTo(_values[12]));
    }

    /// <summary>
    /// Uncommitted writes rolled back to <see cref="Snapshot.Empty"/> before a commit leave the cell zero.
    /// </summary>
    [Test]
    public void Commit_no_changes()
    {
        Context ctx = new();
        WorldState provider = BuildStorageProvider(ctx);
        provider.Set(new StorageCell(ctx.Address1, 1), _values[1]);
        provider.Set(new StorageCell(ctx.Address1, 2), _values[2]);
        provider.Set(new StorageCell(ctx.Address1, 3), _values[3]);
        provider.Restore(Snapshot.Empty);
        provider.Commit(Frontier.Instance);

        Assert.That(provider.Get(new StorageCell(ctx.Address1, 1)).IsZero(), Is.True);
    }

    /// <summary>
    /// Interleaved reads, writes and partial restores, ending with a restore of all persistent changes,
    /// must leave the cell zero after commit.
    /// </summary>
    [Test]
    public void Commit_no_changes_2()
    {
        Context ctx = new();
        WorldState provider = BuildStorageProvider(ctx);
        provider.Get(new StorageCell(ctx.Address1, 1));
        provider.Get(new StorageCell(ctx.Address1, 1));
        provider.Get(new StorageCell(ctx.Address1, 1));
        provider.Set(new StorageCell(ctx.Address1, 1), _values[1]);
        provider.Set(new StorageCell(ctx.Address1, 1), _values[2]);
        provider.Set(new StorageCell(ctx.Address1, 1), _values[3]);
        provider.Restore(Snapshot.EmptyPosition, 2, Snapshot.EmptyPosition);
        provider.Restore(Snapshot.EmptyPosition, 1, Snapshot.EmptyPosition);
        provider.Restore(Snapshot.EmptyPosition, 0, Snapshot.EmptyPosition);
        provider.Get(new StorageCell(ctx.Address1, 1));
        provider.Get(new StorageCell(ctx.Address1, 1));
        provider.Get(new StorageCell(ctx.Address1, 1));
        provider.Set(new StorageCell(ctx.Address1, 1), _values[1]);
        provider.Set(new StorageCell(ctx.Address1, 1), _values[2]);
        provider.Set(new StorageCell(ctx.Address1, 1), _values[3]);
        provider.Restore(Snapshot.EmptyPosition, -1, Snapshot.EmptyPosition);
        provider.Get(new StorageCell(ctx.Address1, 1));
        provider.Get(new StorageCell(ctx.Address1, 1));
        provider.Get(new StorageCell(ctx.Address1, 1));
        provider.Commit(Frontier.Instance);

        Assert.That(provider.Get(new StorageCell(ctx.Address1, 1)).IsZero(), Is.True);
    }

    /// <summary>
    /// A change committed in a scope opened on block 1's state root must not be visible when a new scope
    /// is later opened on that same (previous) root.
    /// </summary>
    [Test]
    public void Commit_trees_clear_caches_get_previous_root()
    {
        Context ctx = new(setInitialState: false);
        // block 1
        Hash256 stateRoot;
        WorldState storageProvider = BuildStorageProvider(ctx);
        using (storageProvider.BeginScope(IWorldState.PreGenesis))
        {
            storageProvider.CreateAccount(ctx.Address1, 0);
            storageProvider.CreateAccount(ctx.Address2, 0);
            storageProvider.Commit(Frontier.Instance);
            storageProvider.Set(new StorageCell(ctx.Address1, 1), _values[1]);
            storageProvider.Commit(Frontier.Instance);
            storageProvider.Commit(Frontier.Instance);
            storageProvider.CommitTree(0);
            stateRoot = ctx.StateProvider.StateRoot;
        }
        BlockHeader newBase = Build.A.BlockHeader.WithStateRoot(stateRoot).TestObject;

        // block 2
        using (storageProvider.BeginScope(newBase))
        {
            storageProvider.Set(new StorageCell(ctx.Address1, 1), _values[2]);
            storageProvider.Commit(Frontier.Instance);
            storageProvider.CommitTree(0);
        }

        using (storageProvider.BeginScope(newBase))
        {
            storageProvider.AccountExists(ctx.Address1).Should().BeTrue();

            byte[] valueAfter = storageProvider.Get(new StorageCell(ctx.Address1, 1)).ToArray();

            Assert.That(valueAfter, Is.EqualTo(_values[1]));
        }
    }

    /// <summary>
    /// Regression: committing after exactly <see cref="Resettable.StartCapacity"/> writes to one cell must
    /// succeed and keep the last value written.
    /// </summary>
    [Test]
    public void Can_commit_when_exactly_at_capacity_regression()
    {
        Context ctx = new();
        // block 1
        WorldState storageProvider = BuildStorageProvider(ctx);
        for (int i = 0; i < Resettable.StartCapacity; i++)
        {
            storageProvider.Set(new StorageCell(ctx.Address1, 1), _values[i % 2]);
        }

        storageProvider.Commit(Frontier.Instance);
        // storageProvider and ctx.StateProvider are the same instance (see BuildStorageProvider),
        // so this is a second commit with no pending changes.
        ctx.StateProvider.Commit(Frontier.Instance);

        byte[] valueAfter = storageProvider.Get(new StorageCell(ctx.Address1, 1)).ToArray();
        // The last write used index (StartCapacity - 1) % 2, which equals (StartCapacity + 1) % 2.
        Assert.That(valueAfter, Is.EqualTo(_values[(Resettable.StartCapacity + 1) % 2]));
    }

    /// <summary>
    /// Transient storage should be zero if uninitialized
    /// </summary>
    [Test]
    public void Can_tload_uninitialized_locations()
    {
        Context ctx = new();
        WorldState provider = BuildStorageProvider(ctx);
        // Should be 0 if not set
        Assert.That(provider.GetTransientState(new StorageCell(ctx.Address1, 1)).IsZero(), Is.True);

        // Should be 0 if loading from the same contract but different index
        provider.SetTransientState(new StorageCell(ctx.Address1, 2), _values[1]);
        Assert.That(provider.GetTransientState(new StorageCell(ctx.Address1, 1)).IsZero(), Is.True);

        // Should be 0 if loading from the same index but different contract
        Assert.That(provider.GetTransientState(new StorageCell(ctx.Address2, 1)).IsZero(), Is.True);
    }

    /// <summary>
    /// Simple transient storage test
    /// </summary>
    [Test]
    public void Can_tload_after_tstore()
    {
        Context ctx = new();
        WorldState provider = BuildStorageProvider(ctx);

        provider.SetTransientState(new StorageCell(ctx.Address1, 2), _values[1]);
        Assert.That(provider.GetTransientState(new StorageCell(ctx.Address1, 2)).ToArray(), Is.EqualTo(_values[1]));
    }

    /// <summary>
    /// Transient storage can be updated and restored
    /// </summary>
    /// <param name="snapshot">Snapshot to restore to</param>
    [TestCase(-1)]
    [TestCase(0)]
    [TestCase(1)]
    [TestCase(2)]
    public void Tload_same_address_same_index_different_values_restore(int snapshot)
    {
        Context ctx = new();
        WorldState provider = BuildStorageProvider(ctx);
        Snapshot[] snapshots = new Snapshot[4];
        snapshots[0] = provider.TakeSnapshot();
        provider.SetTransientState(new StorageCell(ctx.Address1, 1), _values[1]);
        snapshots[1] = provider.TakeSnapshot();
        provider.SetTransientState(new StorageCell(ctx.Address1, 1), _values[2]);
        snapshots[2] = provider.TakeSnapshot();
        provider.SetTransientState(new StorageCell(ctx.Address1, 1), _values[3]);
        snapshots[3] = provider.TakeSnapshot();

        // Actual value first, expected value in the constraint, so failure messages read correctly.
        Assert.That(snapshots[snapshot + 1].StorageSnapshot.TransientStorageSnapshot, Is.EqualTo(snapshot));
        // Persistent storage is unimpacted by transient storage
        Assert.That(snapshots[snapshot + 1].StorageSnapshot.PersistentStorageSnapshot, Is.EqualTo(-1));

        provider.Restore(snapshots[snapshot + 1]);

        Assert.That(provider.GetTransientState(new StorageCell(ctx.Address1, 1)).ToArray(), Is.EqualTo(_values[snapshot + 1]));
    }

    /// <summary>
    /// Commit will reset transient state
    /// </summary>
    [Test]
    public void Commit_resets_transient_state()
    {
        Context ctx = new();
        WorldState provider = BuildStorageProvider(ctx);

        provider.SetTransientState(new StorageCell(ctx.Address1, 2), _values[1]);
        Assert.That(provider.GetTransientState(new StorageCell(ctx.Address1, 2)).ToArray(), Is.EqualTo(_values[1]));

        provider.Commit(Frontier.Instance);
        Assert.That(provider.GetTransientState(new StorageCell(ctx.Address1, 2)).IsZero(), Is.True);
    }

    /// <summary>
    /// Reset will reset transient state
    /// </summary>
    [Test]
    public void Reset_resets_transient_state()
    {
        Context ctx = new();
        WorldState provider = BuildStorageProvider(ctx);

        provider.SetTransientState(new StorageCell(ctx.Address1, 2), _values[1]);
        Assert.That(provider.GetTransientState(new StorageCell(ctx.Address1, 2)).ToArray(), Is.EqualTo(_values[1]));

        provider.Reset();
        Assert.That(provider.GetTransientState(new StorageCell(ctx.Address1, 2)).IsZero(), Is.True);
    }

    /// <summary>
    /// Transient state does not impact persistent state
    /// </summary>
    /// <param name="snapshot">Snapshot to restore to</param>
    [TestCase(-1)]
    [TestCase(0)]
    [TestCase(1)]
    [TestCase(2)]
    public void Transient_state_restores_independent_of_persistent_state(int snapshot)
    {
        Context ctx = new();
        WorldState provider = BuildStorageProvider(ctx);
        Snapshot[] snapshots = new Snapshot[4];

        // No updates
        snapshots[0] = provider.TakeSnapshot();

        // Only update transient
        provider.SetTransientState(new StorageCell(ctx.Address1, 1), _values[1]);
        snapshots[1] = provider.TakeSnapshot();

        // Update both
        provider.SetTransientState(new StorageCell(ctx.Address1, 1), _values[2]);
        provider.Set(new StorageCell(ctx.Address1, 1), _values[9]);
        snapshots[2] = provider.TakeSnapshot();

        // Only update persistent
        provider.Set(new StorageCell(ctx.Address1, 1), _values[8]);
        snapshots[3] = provider.TakeSnapshot();

        provider.Restore(snapshots[snapshot + 1]);

        // The last snapshot did not update transient state, so restoring to it yields the same
        // transient value as the previous snapshot (_values[2]).
        int expectedValueIndex = snapshot == 2 ? snapshot : snapshot + 1;

        snapshots[0].StorageSnapshot.Should().BeEquivalentTo(Snapshot.Storage.Empty);
        snapshots[1].StorageSnapshot.Should().BeEquivalentTo(new Snapshot.Storage(Snapshot.EmptyPosition, 0));
        snapshots[2].StorageSnapshot.Should().BeEquivalentTo(new Snapshot.Storage(0, 1));
        snapshots[3].StorageSnapshot.Should().BeEquivalentTo(new Snapshot.Storage(1, 1));

        provider.GetTransientState(new StorageCell(ctx.Address1, 1)).ToArray().Should().BeEquivalentTo(_values[expectedValueIndex]);
    }

    /// <summary>
    /// Persistent state does not impact transient state
    /// </summary>
    /// <param name="snapshot">Snapshot to restore to</param>
    [TestCase(-1)]
    [TestCase(0)]
    [TestCase(1)]
    [TestCase(2)]
    public void Persistent_state_restores_independent_of_transient_state(int snapshot)
    {
        Context ctx = new();
        WorldState provider = BuildStorageProvider(ctx);
        Snapshot[] snapshots = new Snapshot[4];

        // No updates
        snapshots[0] = provider.TakeSnapshot();

        // Only update persistent
        provider.Set(new StorageCell(ctx.Address1, 1), _values[1]);
        snapshots[1] = provider.TakeSnapshot();

        // Update both
        provider.Set(new StorageCell(ctx.Address1, 1), _values[2]);
        provider.SetTransientState(new StorageCell(ctx.Address1, 1), _values[9]);
        snapshots[2] = provider.TakeSnapshot();

        // Only update transient
        provider.SetTransientState(new StorageCell(ctx.Address1, 1), _values[8]);
        snapshots[3] = provider.TakeSnapshot();

        provider.Restore(snapshots[snapshot + 1]);

        // The last snapshot did not update persistent state, so restoring to it yields the same
        // persistent value as the previous snapshot (_values[2]).
        int expectedValueIndex = snapshot == 2 ? snapshot : snapshot + 1;

        snapshots.Should().Equal(
            Snapshot.Empty,
            new Snapshot(new Snapshot.Storage(0, Snapshot.EmptyPosition), Snapshot.EmptyPosition),
            new Snapshot(new Snapshot.Storage(1, 0), Snapshot.EmptyPosition),
            new Snapshot(new Snapshot.Storage(1, 1), Snapshot.EmptyPosition)
        );

        provider.Get(new StorageCell(ctx.Address1, 1)).ToArray().Should().BeEquivalentTo(_values[expectedValueIndex]);
    }

    /// <summary>
    /// <see cref="WorldState.ClearStorage"/> must hide values that were served from the pre-block storage
    /// cache, and cells that were never read must also read back as zero afterwards.
    /// </summary>
    [Test]
    public void Selfdestruct_clears_cache()
    {
        PreBlockCaches preBlockCaches = new PreBlockCaches();
        Context ctx = new(preBlockCaches);
        WorldState provider = BuildStorageProvider(ctx);
        StorageCell accessedStorageCell = new StorageCell(TestItem.AddressA, 1);
        StorageCell nonAccessedStorageCell = new StorageCell(TestItem.AddressA, 2);
        preBlockCaches.StorageCache[accessedStorageCell] = [1, 2, 3];
        provider.Get(accessedStorageCell);
        provider.Commit(Paris.Instance);
        provider.ClearStorage(TestItem.AddressA);
        provider.Get(accessedStorageCell).ToArray().Should().BeEquivalentTo(StorageTree.ZeroBytes);
        provider.Get(nonAccessedStorageCell).ToArray().Should().BeEquivalentTo(StorageTree.ZeroBytes);
    }

    /// <summary>
    /// A <see cref="WorldState.ClearStorage"/> issued before a commit must still be in effect after the commit,
    /// even for a cell whose value was previously served from the pre-block storage cache.
    /// </summary>
    [Test]
    public void Selfdestruct_persist_between_commit()
    {
        PreBlockCaches preBlockCaches = new PreBlockCaches();
        Context ctx = new(preBlockCaches);
        StorageCell accessedStorageCell = new StorageCell(TestItem.AddressA, 1);
        preBlockCaches.StorageCache[accessedStorageCell] = [1, 2, 3];

        WorldState provider = BuildStorageProvider(ctx);
        provider.Get(accessedStorageCell).ToArray().Should().BeEquivalentTo([1, 2, 3]);
        provider.ClearStorage(TestItem.AddressA);
        provider.Commit(Paris.Instance);
        provider.Get(accessedStorageCell).ToArray().Should().BeEquivalentTo(StorageTree.ZeroBytes);
    }

    /// <summary>
    /// Setting <paramref name="numItems"/> cells and then overwriting them all with [0], without reading them
    /// first, must bring the state root back to its value before the cells were set.
    /// </summary>
    [TestCase(2)]
    [TestCase(1000)]
    public void Set_empty_value_for_storage_cell_without_read_clears_data(int numItems)
    {
        IWorldState worldState = new WorldState(TestTrieStoreFactory.Build(new MemDb(), LimboLogs.Instance), Substitute.For<IDb>(), LogManager);

        using var scope = worldState.BeginScope(IWorldState.PreGenesis);
        worldState.CreateAccount(TestItem.AddressA, 1);
        worldState.Commit(Prague.Instance);
        worldState.CommitTree(0);
        Hash256 emptyHash = worldState.StateRoot;

        for (int i = 0; i < numItems; i++)
        {
            UInt256 asUInt256 = (UInt256)(i + 1);
            worldState.Set(new StorageCell(TestItem.AddressA, (UInt256)i), asUInt256.ToBigEndian());
        }
        worldState.Commit(Prague.Instance);
        worldState.CommitTree(1);

        Hash256 fullHash = worldState.StateRoot;
        fullHash.Should().NotBe(emptyHash);

        for (int i = 0; i < numItems; i++)
        {
            worldState.Set(new StorageCell(TestItem.AddressA, (UInt256)i), [0]);
        }
        worldState.Commit(Prague.Instance);
        worldState.CommitTree(2);

        Hash256 clearedHash = worldState.StateRoot;

        clearedHash.Should().Be(emptyHash);
    }

    /// <summary>
    /// Same as <see cref="Set_empty_value_for_storage_cell_without_read_clears_data"/>, but each cell is read
    /// before being overwritten with [0]; the state root must still return to its pre-write value.
    /// </summary>
    [Test]
    public void Set_empty_value_for_storage_cell_with_read_clears_data()
    {
        IWorldState worldState = new WorldState(TestTrieStoreFactory.Build(new MemDb(), LimboLogs.Instance), Substitute.For<IDb>(), LogManager);

        using var scope = worldState.BeginScope(IWorldState.PreGenesis);
        worldState.CreateAccount(TestItem.AddressA, 1);
        worldState.Commit(Prague.Instance);
        worldState.CommitTree(0);
        Hash256 emptyHash = worldState.StateRoot;

        worldState.Set(new StorageCell(TestItem.AddressA, 1), _values[11]);
        worldState.Set(new StorageCell(TestItem.AddressA, 2), _values[12]);
        worldState.Commit(Prague.Instance);
        worldState.CommitTree(1);

        Hash256 fullHash = worldState.StateRoot;
        fullHash.Should().NotBe(emptyHash);

        worldState.Get(new StorageCell(TestItem.AddressA, 1));
        worldState.Get(new StorageCell(TestItem.AddressA, 2));
        worldState.Set(new StorageCell(TestItem.AddressA, 1), [0]);
        worldState.Set(new StorageCell(TestItem.AddressA, 2), [0]);
        worldState.Commit(Prague.Instance);
        worldState.CommitTree(2);

        Hash256 clearedHash = worldState.StateRoot;

        clearedHash.Should().Be(emptyHash);
    }

    /// <summary>
    /// Per-test fixture: a fresh <see cref="WorldState"/> over an in-memory trie store plus two test addresses.
    /// </summary>
    private class Context
    {
        public WorldState StateProvider { get; }

        public readonly Address Address1 = new(Keccak.Compute("1"));
        public readonly Address Address2 = new(Keccak.Compute("2"));

        /// <param name="preBlockCaches">Optional pre-block caches passed through to the <see cref="WorldState"/>.</param>
        /// <param name="setInitialState">
        /// When true, opens a pre-genesis scope and creates and commits <see cref="Address1"/> and <see cref="Address2"/>.
        /// Tests that manage their own scopes pass false.
        /// </param>
        public Context(PreBlockCaches preBlockCaches = null, bool setInitialState = true)
        {
            StateProvider = new WorldState(TestTrieStoreFactory.Build(new MemDb(), LimboLogs.Instance), Substitute.For<IDb>(), LogManager, preBlockCaches);
            if (setInitialState)
            {
                // The returned scope is deliberately not disposed: tests using this context operate on
                // StateProvider without opening their own scope, so it must stay open for the context's lifetime.
                StateProvider.BeginScope(IWorldState.PreGenesis);
                StateProvider.CreateAccount(Address1, 0);
                StateProvider.CreateAccount(Address2, 0);
                StateProvider.Commit(Frontier.Instance);
            }
        }
    }
}
